{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of nearest neighbours to connect to each point.\n",
    "K = 2\n",
    "# Partition each row so that columns 0..K hold the K + 1 smallest squared\n",
    "# distances (normally the point itself plus its K nearest neighbours), in\n",
    "# arbitrary order. kth=K is sufficient for that and, unlike kth=K + 1, it\n",
    "# still works when there are exactly K + 1 points.\n",
    "# NOTE(review): assumes dist_sq_1[i, j] is the squared distance between\n",
    "# X_1[i] and X_1[j] (computed in an earlier cell) - confirm.\n",
    "nearest_partition = np.argpartition(dist_sq_1, K, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw each point and connect it to its K nearest neighbours.\n",
    "# K and nearest_partition come from the previous cell; K is deliberately not\n",
    "# redefined here so the slice below always matches the computed partition.\n",
    "plt.scatter(X_1[:, 0], X_1[:, 1], s=100)\n",
    "for i_1 in range(X_1.shape[0]):\n",
    "    # Columns 0..K hold the K + 1 closest indices, normally including i_1 itself.\n",
    "    for j in nearest_partition[i_1, :K + 1]:\n",
    "        if j == i_1:\n",
    "            continue  # skip the zero-length segment from a point to itself\n",
    "        plt.plot(*zip(X_1[j], X_1[i_1]), color='black')"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}